{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "rOvvWAVTkMR7"
      },
      "source": [
        "# Intro to Object Detection Colab\n",
        "\n",
        "Welcome to the object detection colab!  This demo will take you through the steps of running an \"out-of-the-box\" detection model on a collection of images."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "vPs64QA1Zdov"
      },
      "source": [
        "## Imports and Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "LBZ9VWZZFUCT"
      },
      "outputs": [],
      "source": [
        "# Install (or upgrade to) the pinned TensorFlow 2.2.0 release.\n",
        "!pip install -U --pre tensorflow==\"2.2.0\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "oi28cqGGFWnY"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import pathlib\n",
        "\n",
        "# Make sure the tensorflow models repository sits below the working directory:\n",
        "# clone it when it is missing, or step out of the checkout if we are inside it.\n",
        "if \"models\" not in pathlib.Path.cwd().parts:\n",
        "  if not pathlib.Path('models').exists():\n",
        "    !git clone --depth 1 https://github.com/tensorflow/models\n",
        "else:\n",
        "  # Walk upwards until no path component is named `models` any more.\n",
        "  while \"models\" in pathlib.Path.cwd().parts:\n",
        "    os.chdir('..')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "NwdsBdGhFanc"
      },
      "outputs": [],
      "source": [
        "%%bash\n",
        "# Install the Object Detection API.\n",
        "# NOTE: `%%bash` has to be the very first line of the cell -- IPython only\n",
        "# recognizes cell magics there -- so this comment lives inside the bash script.\n",
        "cd models/research/\n",
        "# Compile the protobuf definitions into Python modules.\n",
        "protoc object_detection/protos/*.proto --python_out=.\n",
        "# Use the TF2 variant of setup.py and install the package with pip.\n",
        "cp object_detection/packages/tf2/setup.py .\n",
        "python -m pip install ."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "yn5_uV1HLvaz"
      },
      "outputs": [],
      "source": [
        "import matplotlib\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import io\n",
        "import scipy.misc\n",
        "import numpy as np\n",
        "from six import BytesIO\n",
        "from PIL import Image, ImageDraw, ImageFont\n",
        "\n",
        "import tensorflow as tf\n",
        "\n",
        "from object_detection.utils import label_map_util\n",
        "from object_detection.utils import config_util\n",
        "from object_detection.utils import visualization_utils as viz_utils\n",
        "from object_detection.builders import model_builder\n",
        "\n",
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "IogyryF2lFBL"
      },
      "source": [
        "## Utilities"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "-y9R0Xllefec"
      },
      "outputs": [],
      "source": [
        "def load_image_into_numpy_array(path):\n",
        "  \"\"\"Load an image from file into a numpy array.\n",
        "\n",
        "  Puts image into numpy array to feed into tensorflow graph.\n",
        "  Note that by convention we put it into a numpy array with shape\n",
        "  (height, width, channels), where channels=3 for RGB.\n",
        "\n",
        "  Images in any other mode (e.g. grayscale, palette or RGBA files) are\n",
        "  converted to RGB first, so the result always has exactly three channels.\n",
        "\n",
        "  Args:\n",
        "    path: the file path to the image\n",
        "\n",
        "  Returns:\n",
        "    uint8 numpy array with shape (img_height, img_width, 3)\n",
        "  \"\"\"\n",
        "  # Read inside a `with` block so the file handle is closed deterministically.\n",
        "  with tf.io.gfile.GFile(path, 'rb') as image_file:\n",
        "    img_data = image_file.read()\n",
        "  # convert('RGB') guarantees three channels, so grayscale/RGBA inputs work too.\n",
        "  image = Image.open(BytesIO(img_data)).convert('RGB')\n",
        "  # np.array copies the pixel data into a (height, width, 3) array.\n",
        "  return np.array(image, dtype=np.uint8)\n",
        "\n",
        "def get_keypoint_tuples(eval_config):\n",
        "  \"\"\"Return a tuple list of keypoint edges from the eval config.\n",
        "  \n",
        "  Args:\n",
        "    eval_config: an eval config containing the keypoint edges\n",
        "  \n",
        "  Returns:\n",
        "    a list of edge tuples, each in the format (start, end)\n",
        "  \"\"\"\n",
        "  tuple_list = []\n",
        "  kp_list = eval_config.keypoint_edge\n",
        "  for edge in kp_list:\n",
        "    tuple_list.append((edge.start, edge.end))\n",
        "  return tuple_list"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "R4YjnOjME1gy"
      },
      "outputs": [],
      "source": [
        "# @title Choose the model to use, then evaluate the cell.\n",
        "MODELS = {'centernet_with_keypoints': 'centernet_hg104_512x512_kpts_coco17_tpu-32', 'centernet_without_keypoints': 'centernet_hg104_512x512_coco17_tpu-8'}\n",
        "\n",
        "model_display_name = 'centernet_with_keypoints' # @param ['centernet_with_keypoints', 'centernet_without_keypoints']\n",
        "model_name = MODELS[model_display_name]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "6917xnUSlp9x"
      },
      "source": [
        "### Build a detection model and load pre-trained model weights\n",
        "\n",
        "This sometimes takes a little while, please be patient!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "ctPavqlyPuU_"
      },
      "outputs": [],
      "source": [
        "# Download the checkpoint and put it into models/research/object_detection/test_data/\n",
        "\n",
        "# `model_name` is both the archive stem and the directory the archive unpacks\n",
        "# to, so a single command sequence serves every entry of MODELS (IPython\n",
        "# expands `{model_name}` before handing the line to the shell).\n",
        "!wget http://download.tensorflow.org/models/object_detection/tf2/20200711/{model_name}.tar.gz\n",
        "!tar -xf {model_name}.tar.gz\n",
        "# Drop any checkpoint left over from an earlier run first: `mv` cannot replace\n",
        "# a non-empty directory, so a stale checkpoint of a different model would\n",
        "# otherwise survive and be restored into the newly selected model.\n",
        "!rm -rf models/research/object_detection/test_data/checkpoint\n",
        "!mv {model_name}/checkpoint models/research/object_detection/test_data/"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "4cni4SSocvP_"
      },
      "outputs": [],
      "source": [
        "# Pipeline config of the selected model and the directory that holds the\n",
        "# downloaded checkpoint.\n",
        "pipeline_config = os.path.join(\n",
        "    'models/research/object_detection/configs/tf2/', model_name + '.config')\n",
        "model_dir = 'models/research/object_detection/test_data/checkpoint/'\n",
        "\n",
        "# Parse the pipeline config and construct the detection model for inference.\n",
        "configs = config_util.get_configs_from_pipeline_file(pipeline_config)\n",
        "model_config = configs['model']\n",
        "detection_model = model_builder.build(model_config=model_config, is_training=False)\n",
        "\n",
        "# Load the pre-trained weights from `ckpt-0` into the freshly built model.\n",
        "ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)\n",
        "ckpt.restore(os.path.join(model_dir, 'ckpt-0')).expect_partial()\n",
        "\n",
        "def get_model_detection_function(model):\n",
        "  \"\"\"Build a tf.function that runs the full detection pipeline of `model`.\n",
        "\n",
        "  Args:\n",
        "    model: detection model providing `preprocess`, `predict` and `postprocess`\n",
        "\n",
        "  Returns:\n",
        "    A tf.function that maps an image tensor to the tuple\n",
        "    (detections, prediction_dict, shapes flattened to 1-D).\n",
        "  \"\"\"\n",
        "\n",
        "  @tf.function\n",
        "  def detect_fn(image):\n",
        "    \"\"\"Detect objects in image.\"\"\"\n",
        "    preprocessed, true_shapes = model.preprocess(image)\n",
        "    raw_predictions = model.predict(preprocessed, true_shapes)\n",
        "    postprocessed = model.postprocess(raw_predictions, true_shapes)\n",
        "    flat_shapes = tf.reshape(true_shapes, [-1])\n",
        "    return postprocessed, raw_predictions, flat_shapes\n",
        "\n",
        "  return detect_fn\n",
        "\n",
        "detect_fn = get_model_detection_function(detection_model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "NKtD0IeclbL5"
      },
      "source": [
        "# Load label map data (for plotting).\n",
        "\n",
        "Label maps map index numbers to category names, so that when our convolutional network predicts `5`, we know that this corresponds to `airplane`.  Here we use internal utility functions, but anything that returns a dictionary mapping integers to appropriate string labels would be fine."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "5mucYUS6exUJ"
      },
      "outputs": [],
      "source": [
        "# Read the label map referenced by the eval input config.\n",
        "label_map_path = configs['eval_input_config'].label_map_path\n",
        "label_map = label_map_util.load_labelmap(label_map_path)\n",
        "\n",
        "# category_index is what the visualization utilities consume.\n",
        "categories = label_map_util.convert_label_map_to_categories(\n",
        "    label_map,\n",
        "    use_display_name=True,\n",
        "    max_num_classes=label_map_util.get_max_label_map_index(label_map))\n",
        "category_index = label_map_util.create_category_index(categories)\n",
        "\n",
        "# label_map_dict is used to look up a class index by its name.\n",
        "label_map_dict = label_map_util.get_label_map_dict(\n",
        "    label_map, use_display_name=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "RLusV1o-mAx8"
      },
      "source": [
        "### Putting everything together!\n",
        "\n",
        "Run the below code which loads an image, runs it through the detection model and visualizes the detection results, including the keypoints.\n",
        "\n",
        "Note that this will take a long time (several minutes) the first time you run this code due to tf.function's trace-compilation --- on subsequent runs (e.g. on new images), things will be faster.\n",
        "\n",
        "Here are some simple things to try out if you are curious:\n",
        "* Try running inference on your own images (local paths work)\n",
        "* Modify some of the input images and see if detection still works.  Some simple things to try out here (just uncomment the relevant portions of code) include flipping the image horizontally, or converting to grayscale (note that we still expect the input image to have 3 channels).\n",
        "* Print out `detections['detection_boxes']` and try to match the box locations to the boxes in the image.  Notice that coordinates are given in normalized form (i.e., in the interval [0, 1]).\n",
        "* Set min_score_thresh to other values (between 0 and 1) to allow more detections in or to filter out more detections.\n",
        "\n",
        "Note that you can run this cell repeatedly without rerunning earlier cells.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "vr_Fux-gfaG9"
      },
      "outputs": [],
      "source": [
        "# Load one of the bundled test images as a (height, width, 3) uint8 array.\n",
        "image_dir = 'models/research/object_detection/test_images/'\n",
        "image_path = os.path.join(image_dir, 'image2.jpg')\n",
        "image_np = load_image_into_numpy_array(image_path)\n",
        "\n",
        "# Things to try:\n",
        "# Flip horizontally\n",
        "# image_np = np.fliplr(image_np).copy()\n",
        "\n",
        "# Convert image to grayscale\n",
        "# image_np = np.tile(\n",
        "#     np.mean(image_np, 2, keepdims=True), (1, 1, 3)).astype(np.uint8)\n",
        "\n",
        "# Add a leading batch dimension and run the detector.\n",
        "input_tensor = tf.convert_to_tensor(image_np[np.newaxis, ...], dtype=tf.float32)\n",
        "detections, predictions_dict, shapes = detect_fn(input_tensor)\n",
        "\n",
        "# Offset added to the predicted class ids before they are looked up in\n",
        "# category_index.\n",
        "label_id_offset = 1\n",
        "# Draw on a copy so that image_np itself stays unannotated.\n",
        "image_np_with_detections = image_np.copy()\n",
        "\n",
        "# Keypoints are optional: stay with None when the detections carry none.\n",
        "keypoints, keypoint_scores = None, None\n",
        "if 'detection_keypoints' in detections:\n",
        "  keypoints = detections['detection_keypoints'][0].numpy()\n",
        "  keypoint_scores = detections['detection_keypoint_scores'][0].numpy()\n",
        "\n",
        "# Annotate the copy in place with boxes, labels and (if present) keypoints.\n",
        "viz_utils.visualize_boxes_and_labels_on_image_array(\n",
        "    image_np_with_detections,\n",
        "    detections['detection_boxes'][0].numpy(),\n",
        "    (detections['detection_classes'][0].numpy() + label_id_offset).astype(int),\n",
        "    detections['detection_scores'][0].numpy(),\n",
        "    category_index,\n",
        "    keypoints=keypoints,\n",
        "    keypoint_scores=keypoint_scores,\n",
        "    keypoint_edges=get_keypoint_tuples(configs['eval_config']),\n",
        "    use_normalized_coordinates=True,\n",
        "    agnostic_mode=False,\n",
        "    max_boxes_to_draw=200,\n",
        "    min_score_thresh=.30)\n",
        "\n",
        "plt.figure(figsize=(12,16))\n",
        "plt.imshow(image_np_with_detections)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "lYnOxprty3TD"
      },
      "source": [
        "## Digging into the model's intermediate predictions\n",
        "\n",
        "For this part we will assume that the detection model is a CenterNet model following Zhou et al (https://arxiv.org/abs/1904.07850).  And more specifically, we will assume that `detection_model` is of type `meta_architectures.center_net_meta_arch.CenterNetMetaArch`.\n",
        "\n",
        "As one of its intermediate predictions, CenterNet produces a heatmap of box centers for each class (for example, for the class \"zebra\" it will produce a heatmap, whose size is proportional to that of the image, that lights up at the center of each zebra). In the following, we will visualize these intermediate class center heatmap predictions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "xBgYgSGMhHVi"
      },
      "outputs": [],
      "source": [
        "# Guard: the heatmap code in this cell reads CenterNet-specific fields\n",
        "# (`object_center` predictions, `detection_model._stride`), so stop early with\n",
        "# a clear error when a different meta-architecture was built.\n",
        "if detection_model.__class__.__name__ != 'CenterNetMetaArch':\n",
        "  raise AssertionError('The meta-architecture for this section '\n",
        "  'is assumed to be CenterNetMetaArch!')\n",
        "\n",
        "def get_heatmap(predictions_dict, class_name):\n",
        "  \"\"\"Grabs class center logits and applies the inverse logit (sigmoid) transform.\n",
        "\n",
        "  Args:\n",
        "    predictions_dict: dictionary of tensors containing a `object_center`\n",
        "      field.  The field is indexed twice with 0 below; the entry selected by\n",
        "      the first index is presumably of shape\n",
        "      [1, heatmap_height, heatmap_width, num_classes] -- TODO confirm.\n",
        "    class_name: string name of category (e.g., `horse`)\n",
        "\n",
        "  Returns:\n",
        "    heatmap: 2d Tensor heatmap representing heatmap of centers for a given class\n",
        "      (For CenterNet, this is 128x128 or 256x256) with values in [0,1]\n",
        "  \"\"\"\n",
        "  class_index = label_map_dict[class_name]\n",
        "  # First [0]: first entry of the `object_center` field; second [0]: first\n",
        "  # element along that entry's leading axis.\n",
        "  class_center_logits = predictions_dict['object_center'][0][0]\n",
        "  # Label map ids are shifted by label_id_offset relative to the class channels.\n",
        "  class_center_logits = class_center_logits[:, :, class_index - label_id_offset]\n",
        "  # tf.sigmoid is the numerically stable form of exp(x) / (exp(x) + 1); the\n",
        "  # explicit ratio overflows to inf / inf = NaN for large positive logits.\n",
        "  return tf.sigmoid(class_center_logits)\n",
        "\n",
        "def unpad_heatmap(heatmap, image_np, true_image_shape=None):\n",
        "  \"\"\"Reshapes/unpads heatmap appropriately.\n",
        "\n",
        "  Upsamples the heatmap by the model stride, slices off everything beyond\n",
        "  `true_image_shape` and resizes the remainder to the size of `image_np`.\n",
        "\n",
        "  Args:\n",
        "    heatmap: Output of `get_heatmap`, a 2d Tensor\n",
        "    image_np: uint8 numpy array with shape (img_height, img_width, 3).  Note\n",
        "      that due to padding, the relationship between img_height and img_width\n",
        "      might not be a simple scaling.\n",
        "    true_image_shape: optional `size` argument for the `tf.slice` call that\n",
        "      removes the padding (three entries, matching the rank of the tiled\n",
        "      heatmap).  Defaults to the notebook-level `shapes` value returned by\n",
        "      `detect_fn`, which is what this function always used.\n",
        "\n",
        "  Returns:\n",
        "    resized_heatmap_unpadded: a resized heatmap (2d Tensor) that is the same\n",
        "      size as `image_np`\n",
        "  \"\"\"\n",
        "  if true_image_shape is None:\n",
        "    # Backward-compatible default: fall back to the global set by the\n",
        "    # inference cell.\n",
        "    true_image_shape = shapes\n",
        "  # Replicate the single channel three times and scale the values by 255.\n",
        "  heatmap = tf.tile(tf.expand_dims(heatmap, 2), [1, 1, 3]) * 255\n",
        "  stride = detection_model._stride\n",
        "  # Scale each axis on its own so that non-square heatmaps are handled as well.\n",
        "  pre_strided_size = [stride * heatmap.shape[0], stride * heatmap.shape[1]]\n",
        "  resized_heatmap = tf.image.resize(\n",
        "      heatmap, pre_strided_size,\n",
        "      method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
        "  resized_heatmap_unpadded = tf.slice(\n",
        "      resized_heatmap, begin=[0, 0, 0], size=true_image_shape)\n",
        "  # Resize to the original image size and drop the replicated channels again.\n",
        "  return tf.image.resize(\n",
        "      resized_heatmap_unpadded,\n",
        "      [image_np.shape[0], image_np.shape[1]],\n",
        "      method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)[:,:,0]\n",
        "\n",
        "\n",
        "# Overlay the object-center heatmap of each class of interest on top of the\n",
        "# image with the drawn detections, one figure per class.\n",
        "for class_name in ('kite', 'person'):\n",
        "  heatmap = get_heatmap(predictions_dict, class_name)\n",
        "  resized_heatmap_unpadded = unpad_heatmap(heatmap, image_np)\n",
        "  plt.figure(figsize=(12,16))\n",
        "  plt.imshow(image_np_with_detections)\n",
        "  plt.imshow(resized_heatmap_unpadded, alpha=0.7,vmin=0, vmax=160, cmap='viridis')\n",
        "  plt.title('Object center heatmap (class: ' + class_name + ')')\n",
        "  plt.show()"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "inference_tf2_colab.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
